{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sentiment analysis using PyTorch\n",
    "\n",
    "In this example, we'll use PyTorch to implement an LSTM-based sentiment classifier for the [Large Movie Review Dataset](http://ai.stanford.edu/~amaas/data/sentiment/). The model takes a text sequence representing a movie review as input and outputs a binary prediction of whether the review is positive or negative.\n",
    "\n",
    "_This example is partially based on_ [https://github.com/bentrevett/pytorch-sentiment-analysis](https://github.com/bentrevett/pytorch-sentiment-analysis)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the imports and the configuration. We'll use the `torchtext` package to load and tokenize the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchtext\n",
    "\n",
    "EMBEDDING_SIZE = 100\n",
    "HIDDEN_SIZE = 256"
   ]
  },
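  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's load the dataset, which ships with `torchtext`. Note that `Field`, `LabelField`, and `BucketIterator` belong to the legacy `torchtext` API, which moved to `torchtext.legacy` in version 0.9 and was removed in version 0.12, so this code requires an older `torchtext` release:"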
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the fields\n",
    "TEXT = torchtext.data.Field(\n",
    "    tokenize='spacy',  # use the spaCy tokenizer\n",
    "    lower=True,  # convert all letters to lower case\n",
    "    include_lengths=True,  # include the length of the movie review\n",
    ")\n",
    "\n",
    "LABEL = torchtext.data.LabelField(dtype=torch.float)\n",
    "\n",
    "# Dataset splits\n",
    "train, test = torchtext.datasets.IMDB.splits(TEXT, LABEL)\n",
    "\n",
    "# Build the vocabulary with pretrained GloVe vectors\n",
    "TEXT.build_vocab(train, vectors=torchtext.vocab.GloVe(name='6B', dim=100))\n",
    "LABEL.build_vocab(train)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Create the iterators for the splits\n",
    "train_iter, test_iter = torchtext.data.BucketIterator.splits(\n",
    "    (train, test), sort_within_batch=True, batch_size=64, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll implement the `LSTMModel` class, which uses `torch.nn.LSTM` at its core. `LSTMModel` combines an embedding input layer, the LSTM layer, and a fully connected output layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMModel(torch.nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_size, hidden_size, output_size, pad_idx):\n",
    "        super().__init__()\n",
    "\n",
    "        # Embedding layer\n",
    "        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size,\n",
    "                                            embedding_dim=embedding_size,\n",
    "                                            padding_idx=pad_idx)\n",
    "\n",
    "        # LSTM layer\n",
    "        self.rnn = torch.nn.LSTM(input_size=embedding_size, hidden_size=hidden_size)\n",
    "\n",
    "        # Fully connected output\n",
    "        self.fc = torch.nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, text_sequence, text_lengths):\n",
    "        # Extract embedding vectors\n",
    "        embeddings = self.embedding(text_sequence)\n",
    "\n",
    "        # Pack the padded sequences so that the LSTM skips the padding tokens\n",
    "        # (recent PyTorch versions require the lengths to be on the CPU)\n",
    "        packed_sequence = torch.nn.utils.rnn.pack_padded_sequence(embeddings, text_lengths.cpu())\n",
    "\n",
    "        packed_output, (hidden, cell) = self.rnn(packed_sequence)\n",
    "\n",
    "        # Classify each review from its final hidden state\n",
    "        return self.fc(hidden)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll instantiate the model and initialize its embedding weights with the pretrained GloVe vectors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LSTMModel(vocab_size=len(TEXT.vocab),\n",
    "                  embedding_size=EMBEDDING_SIZE,\n",
    "                  hidden_size=HIDDEN_SIZE,\n",
    "                  output_size=1,\n",
    "                  pad_idx=TEXT.vocab.stoi[TEXT.pad_token])\n",
    "\n",
    "# Copy the pretrained GloVe vectors into the embedding layer\n",
    "model.embedding.weight.data.copy_(TEXT.vocab.vectors)\n",
    "\n",
    "# Zero the embeddings of the unknown and padding tokens\n",
    "model.embedding.weight.data[TEXT.vocab.stoi[TEXT.unk_token]] = torch.zeros(EMBEDDING_SIZE)\n",
    "model.embedding.weight.data[TEXT.vocab.stoi[TEXT.pad_token]] = torch.zeros(EMBEDDING_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we'll implement the training procedure, which is generic and works for feed-forward networks as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, loss_function, optimizer, data_loader):\n",
    "    # set model to training mode\n",
    "    model.train()\n",
    "\n",
    "    current_loss = 0.0\n",
    "    current_acc = 0\n",
    "\n",
    "    # iterate over the training data\n",
    "    for i, batch in enumerate(data_loader):\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        text, text_lengths = batch.text\n",
    "\n",
    "        with torch.set_grad_enabled(True):\n",
    "            # forward\n",
    "            outputs = model(text, text_lengths).squeeze()\n",
    "            loss = loss_function(outputs, batch.label)\n",
    "\n",
    "            # backward\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # statistics (the mean batch loss is weighted by the batch size)\n",
    "        current_loss += loss.item() * text_lengths.shape[0]\n",
    "        current_acc += torch.sum(torch.round(torch.sigmoid(outputs)) == batch.label.data)\n",
    "\n",
    "    total_loss = current_loss / len(data_loader.dataset)\n",
    "    total_acc = current_acc.double() / len(data_loader.dataset)\n",
    "\n",
    "    print('Train Loss: {:.4f}; Accuracy: {:.4f}'.format(total_loss, total_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with the testing procedure, which is also generic:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model(model, loss_function, data_loader):\n",
    "    # set model in evaluation mode\n",
    "    model.eval()\n",
    "\n",
    "    current_loss = 0.0\n",
    "    current_acc = 0\n",
    "\n",
    "    # iterate over the test data\n",
    "    for i, batch in enumerate(data_loader):\n",
    "        text, text_lengths = batch.text\n",
    "\n",
    "        # forward\n",
    "        with torch.set_grad_enabled(False):\n",
    "            outputs = model(text, text_lengths).squeeze()\n",
    "            loss = loss_function(outputs, batch.label)\n",
    "\n",
    "        # statistics (the mean batch loss is weighted by the batch size)\n",
    "        current_loss += loss.item() * text_lengths.shape[0]\n",
    "        current_acc += torch.sum(torch.round(torch.sigmoid(outputs)) == batch.label.data)\n",
    "\n",
    "    total_loss = current_loss / len(data_loader.dataset)\n",
    "    total_acc = current_acc.double() / len(data_loader.dataset)\n",
    "\n",
    "    print('Test Loss: {:.4f}; Accuracy: {:.4f}'.format(total_loss, total_acc))\n",
    "\n",
    "    return total_loss, total_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we'll instantiate the training components and train the model for 5 epochs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "Train Loss: 0.6390; Accuracy: 0.6390\n",
      "Test Loss: 0.6100; Accuracy: 0.6719\n",
      "Epoch 2/5\n",
      "Train Loss: 0.4262; Accuracy: 0.8076\n",
      "Test Loss: 0.3940; Accuracy: 0.8248\n",
      "Epoch 3/5\n",
      "Train Loss: 0.2123; Accuracy: 0.9215\n",
      "Test Loss: 0.3123; Accuracy: 0.8758\n",
      "Epoch 4/5\n",
      "Train Loss: 0.0855; Accuracy: 0.9736\n",
      "Test Loss: 0.4015; Accuracy: 0.8571\n",
      "Epoch 5/5\n",
      "Train Loss: 0.0334; Accuracy: 0.9916\n",
      "Test Loss: 0.5525; Accuracy: 0.8531\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "loss_function = torch.nn.BCEWithLogitsLoss().to(device)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(5):\n",
    "    print(f\"Epoch {epoch + 1}/5\")\n",
    "    train_model(model, loss_function, optimizer, train_iter)\n",
    "    test_model(model, loss_function, test_iter)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
